# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.incubate.passes import ir


def set_resnet_unit_attrs(resnet_unit, has_shortcut):
    """Populate the attributes of a ``resnet_unit`` replacement op.

    Two groups of attributes are written, in this order:

    1. Constant attributes, assigned through ``resnet_unit.SetAttr``.
    2. Attributes bound through ``resnet_unit.Attr(...).MappedPattern(...)``
       to a named attribute of the ``conv2d`` or ``batch_norm`` op of the
       pattern.

    Args:
        resnet_unit: Op helper exposing ``SetAttr`` and ``Attr``.
        has_shortcut: Value stored in the ``has_shortcut`` attribute.
    """
    constant_attrs = (
        ("fuse_add", False),
        ("act_type", "relu"),
        ("has_shortcut", has_shortcut),
        ("data_format", 'NHWC'),
        ("dilation", 1),
    )
    for attr_name, attr_value in constant_attrs:
        resnet_unit.SetAttr(attr_name, attr_value)

    # Each entry: (resnet_unit attribute, keyword arguments for MappedPattern).
    # NOTE(review): element_index=0 presumably selects the first element of a
    # list-valued conv2d attribute -- confirm against the PassDesc API.
    mapped_attrs = (
        ("stride", {"op": "conv2d", "name": "strides", "element_index": 0}),
        ("padding", {"op": "conv2d", "name": "paddings", "element_index": 0}),
        ("group", {"op": "conv2d", "name": "groups"}),
        ("op_device", {"op": "conv2d", "name": "op_device"}),
        ("op_namescope", {"op": "conv2d", "name": "op_namescope"}),
        ("momentum", {"op": "batch_norm", "name": "momentum"}),
        ("epsilon", {"op": "batch_norm", "name": "epsilon"}),
        ("use_global_stats", {"op": "batch_norm", "name": "use_global_stats"}),
    )
    for attr_name, source in mapped_attrs:
        resnet_unit.Attr(attr_name).MappedPattern(**source)


def set_resnet_unit_outputs(resnet_unit, meanX, varX, meanZ=None, varZ=None):
    """Bind the running-statistics outputs of a ``resnet_unit`` op.

    All four output slots are always passed to ``resnet_unit.SetOutputs``;
    the Z-branch slots receive ``None`` when the caller does not supply them.

    Args:
        resnet_unit: Op helper exposing ``SetOutputs``.
        meanX: Value for the ``RunningMeanX`` output.
        varX: Value for the ``RunningVarX`` output.
        meanZ: Value for the ``RunningMeanZ`` output (defaults to ``None``).
        varZ: Value for the ``RunningVarZ`` output (defaults to ``None``).
    """
    running_stats = {
        "RunningMeanX": meanX,
        "RunningVarX": varX,
        "RunningMeanZ": meanZ,
        "RunningVarZ": varZ,
    }
    resnet_unit.SetOutputs(**running_stats)


@ir.RegisterPass
def fuse_resnet_unit():
    """Describe the conv2d + batch_norm (+ add) + relu -> resnet_unit fusion.

    Returns:
        Two ``(pattern, replace)`` function pairs: one for a single
        conv2d/batch_norm branch followed by relu, and one for two such
        branches joined by ``elementwise_add`` before the relu.

    NOTE(review): the parameter names of each pattern function are kept
    identical to those of its replace function; the pass framework presumably
    pairs variables by these names -- confirm before renaming any of them.
    """

    def pattern_conv_bn(x, filter, scale, bias, mean, var):
        """Build one conv2d -> batch_norm branch and return the batch_norm op."""
        # Constraints on the filter "shape" attribute: dim 0 divisible by 32,
        # dim 1 divisible by 8, dims 2 and 3 equal to 1.
        filter.Attr("shape")[0].Mod(32).EQ(0)
        filter.Attr("shape")[1].Mod(8).EQ(0)
        filter.Attr("shape")[2].EQ(1)
        filter.Attr("shape")[3].EQ(1)
        conv_op = ir.PassDesc.OP.conv2d(Input=x, Filter=filter)
        conv_op.SetAttr("data_format", 'NHWC')
        return ir.PassDesc.OP.batch_norm(
            X=conv_op, Bias=bias, Mean=mean, Scale=scale, Variance=var
        )

    def pattern_one_input(x, filter, scale, bias, mean, var):
        """Pattern: a single conv2d/batch_norm branch followed by relu."""
        branch = pattern_conv_bn(x, filter, scale, bias, mean, var)
        return ir.PassDesc.OP.relu(X=branch.Output("Y"))

    def replace_one_input(x, filter, scale, bias, mean, var):
        """Replacement: one resnet_unit op without a shortcut branch."""
        unit_inputs = {
            "X": x,
            "FilterX": filter,
            "ScaleX": scale,
            "BiasX": bias,
            "MeanX": mean,
            "VarX": var,
        }
        fused = ir.PassDesc.OP.resnet_unit(**unit_inputs)
        set_resnet_unit_attrs(fused, False)
        set_resnet_unit_outputs(fused, mean, var)
        return fused.Output("Y")

    def pattern_two_input(
        x,
        filterX,
        scaleX,
        biasX,
        meanX,
        varX,
        z,
        filterZ,
        scaleZ,
        biasZ,
        meanZ,
        varZ,
    ):
        """Pattern: two conv2d/batch_norm branches, added, then relu."""
        branch_x = pattern_conv_bn(x, filterX, scaleX, biasX, meanX, varX)
        branch_z = pattern_conv_bn(z, filterZ, scaleZ, biasZ, meanZ, varZ)
        branch_sum = ir.PassDesc.OP.elementwise_add(
            X=branch_x.Output("Y"), Y=branch_z.Output("Y")
        )
        return ir.PassDesc.OP.relu(X=branch_sum)

    def replace_two_input(
        x,
        filterX,
        scaleX,
        biasX,
        meanX,
        varX,
        z,
        filterZ,
        scaleZ,
        biasZ,
        meanZ,
        varZ,
    ):
        """Replacement: one resnet_unit op with the Z branch as shortcut."""
        unit_inputs = {
            "X": x,
            "FilterX": filterX,
            "ScaleX": scaleX,
            "BiasX": biasX,
            "MeanX": meanX,
            "VarX": varX,
            "Z": z,
            "FilterZ": filterZ,
            "ScaleZ": scaleZ,
            "BiasZ": biasZ,
            "MeanZ": meanZ,
            "VarZ": varZ,
        }
        fused = ir.PassDesc.OP.resnet_unit(**unit_inputs)
        set_resnet_unit_attrs(fused, True)
        set_resnet_unit_outputs(fused, meanX, varX, meanZ, varZ)
        return fused.Output("Y")

    single_branch = (pattern_one_input, replace_one_input)
    shortcut_branch = (pattern_two_input, replace_two_input)
    return single_branch, shortcut_branch
